import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import requests
from urllib.request import urlopen
import imutils
from model import ArtNet
import torch.nn.functional as nnf
import time

from utils import *


# Build the classifier and load its trained weights.
# ArtNet(11): presumably 11 output classes — TODO confirm it matches len(classes).
model = ArtNet(11)
model_path = "/home/shivam-wiz/Downloads/MLPR___/Trial/best_checkpoint.model"
# The model and its inputs stay on the CPU in this script, so remap any CUDA
# tensors in the checkpoint to the CPU; without map_location, a checkpoint saved
# on a GPU fails to load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # evaluation mode: affects layers such as dropout / batch norm

# Snapshot endpoint of the IP-webcam app; change the IP address and port to
# match your device.
url = "http://10.1.16.202:8080/shot.jpg"

def preprocess_image(image_path) -> torch.Tensor:
    """Load an image from disk and turn it into a single-item model input batch.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The output of ``transformer`` (from ``utils``) with a leading batch
        dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Use a context manager so the file handle is closed even if the transform
    # raises. The caller rewrites the same file every frame, so a lingering
    # handle matters. The transform runs inside the block so the pixel data is
    # read before the file is closed.
    with Image.open(image_path) as image:
        input_tensor = transformer(image)
    return input_tensor.unsqueeze(0)  # add a batch dimension

# Continuously fetch frames from the IP webcam, classify each one, and show a
# live preview until Esc is pressed in the preview window.
class_labels = classes  # label list from utils; index i presumably names model output i
CONFIDENCE_THRESHOLD = 0.8  # only report predictions above this probability
REQUEST_TIMEOUT_S = 5  # without a timeout, a stalled camera hangs requests.get forever
RETRY_DELAY_S = 0.5  # pause before retrying after a failed fetch/decode
ESC_KEY = 27

try:
    while True:
        try:
            img_resp = requests.get(url, timeout=REQUEST_TIMEOUT_S)
            img_resp.raise_for_status()
        except requests.RequestException as exc:
            # A dropped connection or HTTP error should not kill the stream;
            # report it and try again shortly.
            print(f"Could not fetch frame from {url}: {exc}")
            time.sleep(RETRY_DELAY_S)
            continue

        img_arr = np.array(bytearray(img_resp.content), dtype=np.uint8)
        img = cv2.imdecode(img_arr, cv2.IMREAD_UNCHANGED)
        if img is None:
            # imdecode signals undecodable data by returning None, not raising.
            print("Received data that could not be decoded as an image; skipping.")
            time.sleep(RETRY_DELAY_S)
            continue

        # imutils.resize keeps the aspect ratio and ignores `height` whenever
        # `width` is given, so only the width is passed.
        new_img = imutils.resize(img, width=1000)

        # preprocess_image reads from disk, so round-trip the frame via a file.
        cv2.imwrite("input_image.jpg", new_img)

        with torch.no_grad():
            input_image = preprocess_image("input_image.jpg")
            output = model(input_image)
            probabilities = nnf.softmax(output[0], dim=0)
            predicted_class = torch.argmax(probabilities).item()

        predicted_label = class_labels[predicted_class]
        confidence = probabilities[predicted_class].item()
        if confidence > CONFIDENCE_THRESHOLD:
            print(f"The model predicts: {predicted_label} with confidence: {confidence:.2%}")

        cv2.imshow("Android_cam", new_img)

        # Mask to the low byte: on some platforms waitKey sets extra modifier bits.
        if cv2.waitKey(1) & 0xFF == ESC_KEY:
            break
finally:
    # Close the preview window even if the loop exits via an exception (e.g. Ctrl+C).
    cv2.destroyAllWindows()
